/*
 * Copyright (c) 2023 Huawei Device Co., Ltd.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "card_channel.h"

#include <securec.h>
#include <stddef.h>

#include "apdu_core_defines.h"
#include "apdu_utils.h"
#include "card_channel.h"
#include "command_apdu_inner.h"
#include "logger.h"
#include "response_apdu_inner.h"

// Index (byte offset) into the open-channel response whose lowest bit
// DefaultScpEnableChecker reads to decide whether SCP is supported.
#define DEFAULT_SCP_SUPPORT_FIELD_INDEX 25

// Overall time budget for the retry loop in ChannelOpenTransmitCloseWithRetry,
// compared against GetCurrentTimeMillis() deltas (i.e. milliseconds).
// Can be overridden at build time.
#ifndef DEFAULT_TRANSMIT_TIMEOUT_LEN
#define DEFAULT_TRANSMIT_TIMEOUT_LEN 5000
#endif

// True when the upper 16 bits of a result code equal ERR_FATAL; NeedNextLoopTransmit
// stops retrying on such errors.
#define IS_FATAL_ERROR(err) (((err) >> 16U) == ERR_FATAL)

/*
 * Allocate a zeroed CardChannel bound to `context`, holding a by-value copy of `*oper`.
 * Returns NULL when `oper` is NULL or the allocation fails. The result is released
 * with DestructCardChannel.
 */
CardChannel *ConstructCardChannel(SecureElementContext *context, ChannelOperations *oper)
{
    CardChannel *created = NULL;

    if (oper != NULL) {
        created = malloc(sizeof(*created));
        if (created == NULL) {
            LOG_ERROR("malloc memory for card channel error");
        } else {
            (void)memset_s(created, sizeof(*created), 0, sizeof(*created));
            created->context = context;
            created->oper = *oper; // copied, so the caller's table need not outlive the channel
        }
    }

    return created;
}

/*
 * Release a channel created by ConstructCardChannel. NULL is accepted and ignored.
 * The struct is scrubbed first because it embeds the SCP key (secureOption.key).
 */
void DestructCardChannel(CardChannel *channel)
{
    if (channel != NULL) {
        (void)memset_s(channel, sizeof(*channel), 0, sizeof(*channel));
        free(channel);
    }
}

/*
 * Send one command APDU over `channel` and return the card's reply.
 *
 * `*responseApdu` must be NULL on entry. On SUCCESS it receives a ResponseApdu created with
 * CreateResponseApdu(MAX_APDU_DATA_SIZE), which the caller owns (presumably released with
 * DestroyResponseApdu). On any failure `*responseApdu` is left untouched and the temporary
 * response is destroyed here.
 */
ResultCode ChannelTransmitApdu(const CardChannel *channel, CommandApdu *commandApdu, ResponseApdu **responseApdu)
{
    if (!(channel != NULL && channel->oper.transmit != NULL)) {
        LOG_ERROR("invalid input channel");
        return INVALID_PARA_NULL_PTR;
    }

    if (!(commandApdu != NULL && responseApdu != NULL && *responseApdu == NULL)) {
        LOG_ERROR("invalid input commandApdu");
        return INVALID_PARA_NULL_PTR;
    }

    // Allocate the reply buffer before serializing the command, as the original ordering did.
    ResponseApdu *reply = CreateResponseApdu(MAX_APDU_DATA_SIZE);
    if (reply == NULL) {
        LOG_ERROR("malloc for responseApdu error");
        return MEM_ALLOC_ERR;
    }

    uint32_t cmdLen = 0;
    const uint8_t *cmdBuf = GetApduCommandBuffer(commandApdu, &cmdLen);
    if (cmdBuf == NULL || cmdLen == 0) {
        LOG_ERROR("get apdu buffer error");
        DestroyResponseApdu(reply);
        return MEM_READ_ERR;
    }

    ResultCode status = channel->oper.transmit(channel->context, cmdBuf, cmdLen, reply->data, &reply->length);
    if (status != SUCCESS) {
        LOG_ERROR("channel transmit error = 0x%x", status);
        DestroyResponseApdu(reply);
        return status;
    }

    *responseApdu = reply; // ownership passes to the caller
    return SUCCESS;
}

/*
 * Open the underlying channel via the backend's `open` hook.
 * Returns CHN_OPEN_ERR for a NULL channel, SUCCESS when the backend has no open hook,
 * otherwise whatever the hook returns.
 */
ResultCode ChannelOpenChannel(const CardChannel *channel)
{
    if (channel != NULL) {
        // A backend without an open hook needs no explicit open step.
        return (channel->oper.open != NULL) ? channel->oper.open(channel->context) : SUCCESS;
    }

    LOG_ERROR("invalid card channel");
    return CHN_OPEN_ERR;
}

/*
 * Close the channel. Any secure session is torn down first; the result of that step is
 * deliberately ignored. Returns CHN_CLOSE_ERR for a NULL channel, SUCCESS when the backend
 * has no close hook, otherwise whatever the hook returns.
 */
ResultCode ChannelCloseChannel(const CardChannel *channel)
{
    if (channel != NULL) {
        (void)ChannelCloseChannelSecure(channel);
        return (channel->oper.close == NULL) ? SUCCESS : channel->oper.close(channel->context);
    }

    LOG_ERROR("invalid card channel");
    return CHN_CLOSE_ERR;
}

ResultCode ChannelOpenChannelSecure(const CardChannel *channel, const uint8_t *key, uint32_t keyLen, uint8_t keyVersion,
    uint8_t keyId)
{
    if (channel == NULL) {
        LOG_ERROR("invalid card channel");
        return CHN_CLOSE_ERR;
    }

    if (!channel->secureOption.enable) {
        return SUCCESS;
    }

    ScpEnableChecker *checker = channel->secureOption.checker;
    if (checker != NULL && !checker(channel)) {
        LOG_INFO("secureOption is enable but not check success");
        return SUCCESS;
    }

    if (channel->oper.openSecure == NULL) {
        // no need to open
        return SUCCESS;
    }
    return channel->oper.openSecure(channel->context, key, keyLen, keyVersion, keyId);
}

/*
 * Tear down the secure (SCP) session, if any.
 * Returns CHN_CLOSE_ERR for a NULL channel, the backend closeSecure result when SCP is
 * enabled and the hook exists, and SUCCESS otherwise.
 */
ResultCode ChannelCloseChannelSecure(const CardChannel *channel)
{
    if (channel == NULL) {
        LOG_ERROR("invalid card channel");
        return CHN_CLOSE_ERR;
    }

    // Nothing to tear down when SCP was never enabled or the backend has no closeSecure hook.
    if (channel->secureOption.enable && channel->oper.closeSecure != NULL) {
        return channel->oper.closeSecure(channel->context);
    }
    return SUCCESS;
}

/*
 * Fetch the response recorded when the channel was opened, via the backend's
 * getOpenResponse hook. `*responseLen` is passed straight to the hook (presumably
 * capacity in, length out — confirm against the backend).
 * Returns INVALID_PARA_NULL_PTR for NULL arguments or a missing hook.
 */
ResultCode ChannelGetOpenChannelResponse(const CardChannel *channel, uint8_t *response, uint32_t *responseLen)
{
    if (channel != NULL && response != NULL && responseLen != NULL) {
        if (channel->oper.getOpenResponse != NULL) {
            return channel->oper.getOpenResponse(channel->context, response, responseLen);
        }
        return INVALID_PARA_NULL_PTR;
    }

    LOG_ERROR("invalid card channel");
    return INVALID_PARA_NULL_PTR;
}

/*
 * Decide whether the retry loop should run attempt number `currLoop` (1-based).
 * Stops (returns false) once DEFAULT_TRANSMIT_TIMEOUT_LEN ms have elapsed since `startTick`
 * or when `retCode` is a fatal error; otherwise continues while currLoop <= maxLoop.
 */
bool NeedNextLoopTransmit(uint32_t currLoop, uint32_t maxLoop, uint64_t startTick, ResultCode retCode)
{
    uint64_t now = GetCurrentTimeMillis();

    // A clock that appears to have gone backwards counts as zero elapsed time.
    uint64_t elapsed = 0;
    if (now > startTick) {
        elapsed = now - startTick;
    }

    bool outOfTime = (elapsed >= DEFAULT_TRANSMIT_TIMEOUT_LEN);
    if (outOfTime || IS_FATAL_ERROR(retCode)) {
        LOG_INFO("break current operate for cost is %u, retCode is 0x%x", (uint32_t)elapsed, retCode);
        return false;
    }

    return currLoop <= maxLoop;
}

/*
 * Run a complete open -> (secure open) -> transmit -> check -> close cycle, retrying the
 * whole cycle on failure.
 *
 * `retry` is the maximum number of attempts (attempts are numbered 1..retry). With
 * retry == 0 no attempt is made and the initial ERR_GENERIC_ERR is returned. Retrying
 * also stops early on timeout or on a fatal error (see NeedNextLoopTransmit).
 *
 * `checker`, when non-NULL, validates each response. A rejected response is destroyed
 * and the cycle is retried.
 *
 * `*responseApdu` must be NULL on entry. On SUCCESS it holds the accepted response,
 * owned by the caller. On failure it is left NULL.
 *
 * Returns the result of the last step that ran.
 */
ResultCode ChannelOpenTransmitCloseWithRetry(const CardChannel *channel, uint32_t retry, ResponseChecker checker,
    CommandApdu *commandApdu, ResponseApdu **responseApdu)
{
    // Reject a non-NULL *responseApdu here: ChannelTransmitApdu would refuse it on every
    // attempt, and the loop would open and close the card channel repeatedly for nothing.
    if (channel == NULL || commandApdu == NULL || responseApdu == NULL || *responseApdu != NULL) {
        LOG_ERROR("invalid input");
        return INVALID_PARA_NULL_PTR;
    }

    uint64_t startTick = GetCurrentTimeMillis();
    ResultCode ret = ERR_GENERIC_ERR;
    for (uint32_t loop = 1; NeedNextLoopTransmit(loop, retry, startTick, ret); loop++) {
        ret = ChannelOpenChannel(channel);
        if (ret != SUCCESS) {
            LOG_ERROR("channel open error = 0x%x, retry = %u/%u", ret, loop, retry);
            continue;
        }

        const SecureOption *opt = &channel->secureOption;
        if (opt->enable) {
            ret = ChannelOpenChannelSecure(channel, opt->key, opt->keyLen, opt->keyVersion, opt->keyId);
            if (ret != SUCCESS) {
                LOG_ERROR("channel open secure error = 0x%x, retry = %u/%u", ret, loop, retry);
                (void)ChannelCloseChannel(channel);
                continue;
            }
        }

        ret = ChannelTransmitApdu(channel, commandApdu, responseApdu);
        if (ret != SUCCESS) {
            LOG_ERROR("channel transmit error = 0x%x, retry = %u/%u", ret, loop, retry);
            (void)ChannelCloseChannel(channel);
            continue;
        }

        if (checker != NULL) {
            ret = checker(*responseApdu);
            if (ret != SUCCESS) {
                LOG_ERROR("channel check error = 0x%x, retry = %u/%u", ret, loop, retry);
                // Drop the rejected response so the next attempt starts from *responseApdu == NULL.
                DestroyResponseApdu(*responseApdu);
                *responseApdu = NULL;
                (void)ChannelCloseChannel(channel);
                continue;
            }
        }

        (void)ChannelCloseChannel(channel);
        break;
    }
    return ret;
}

/*
 * Store SCP key material on the channel and mark the secure option enabled.
 *
 * Returns INVALID_PARA_NULL_PTR for a NULL channel or key, MEM_COPY_ERR when keyLen
 * exceeds SECURE_OPTION_KEY_MAX_SIZE or the copy fails, and SUCCESS otherwise.
 *
 * The arguments are validated before option->key is touched. securec's memcpy_s clears
 * the destination on several error paths, so copying first could wipe a previously
 * configured key while `enable` and the old `keyLen` stayed set.
 */
ResultCode ChannelEnableSecureChannelProtocol(CardChannel *channel, const uint8_t *key, uint32_t keyLen,
    uint8_t keyVersion, uint8_t keyId)
{
    if (channel == NULL || key == NULL) {
        return INVALID_PARA_NULL_PTR;
    }

    // option->key is assumed to be SECURE_OPTION_KEY_MAX_SIZE bytes, as the memcpy_s destMax implies.
    if (keyLen > SECURE_OPTION_KEY_MAX_SIZE) {
        LOG_ERROR("invalid scp key length %u", keyLen);
        return MEM_COPY_ERR;
    }

    SecureOption *option = &channel->secureOption;
    if (memcpy_s(option->key, SECURE_OPTION_KEY_MAX_SIZE, key, keyLen) != EOK) {
        return MEM_COPY_ERR;
    }
    option->keyLen = keyLen;
    option->keyVersion = keyVersion;
    option->keyId = keyId;
    option->enable = 1;
    return SUCCESS;
}

/*
 * Install (or clear, with NULL) the predicate that ChannelOpenChannelSecure consults
 * before opening a secure session. Returns INVALID_PARA_NULL_PTR for a NULL channel.
 */
ResultCode ChannelSetSecureChannelProtocolChecker(CardChannel *channel, ScpEnableChecker checker)
{
    if (channel != NULL) {
        channel->secureOption.checker = checker;
        return SUCCESS;
    }
    return INVALID_PARA_NULL_PTR;
}

/*
 * Accept a response only when its status word equals SW_NO_ERROR.
 *
 * Returns:
 *   - SUCCESS on acceptance;
 *   - INVALID_PARA_NULL_PTR for NULL input;
 *   - RES_APDU_SW_ERR when the status word cannot be read;
 *   - otherwise RES_APDU_SW_ERR shifted into the upper 16 bits, OR'd with the offending SW
 *     in the low 16 bits.
 */
ResultCode DefaultResponseChecker(const ResponseApdu *apdu)
{
    if (apdu == NULL) {
        return INVALID_PARA_NULL_PTR;
    }

    uint16_t statusWord = 0;
    if (!GetStatusWords(apdu, &statusWord)) {
        LOG_ERROR("GetStatusWords failed");
        return RES_APDU_SW_ERR;
    }

    if (statusWord == SW_NO_ERROR) {
        return SUCCESS;
    }

    LOG_ERROR("CheckStatusWords failed, sw is 0x%x", (uint32_t)statusWord);
    // Shift by 16 so the raw SW stays recoverable from the low half of the result.
    return (uint32_t)statusWord | (RES_APDU_SW_ERR << 16U);
}

/*
 * Default SCP predicate. It reads the open-channel response and returns the lowest bit of
 * byte DEFAULT_SCP_SUPPORT_FIELD_INDEX: 1 means SCP is supported, 0 means it is not.
 *
 * NOTE(review): ChannelOpenChannelSecure treats any non-zero return as "supported", which
 * makes the two failure paths below behave differently:
 *   - a NULL channel or a failing ChannelGetOpenChannelResponse returns a non-zero error
 *     code, which is read as "supported";
 *   - a response that is too short returns SUCCESS (0), which is read as "not supported".
 * Confirm which of these is intended.
 */
ResultCode DefaultScpEnableChecker(const CardChannel *channel)
{
    if (channel == NULL) {
        return INVALID_PARA_NULL_PTR;
    }

    uint32_t respLen = MAX_APDU_RESP_SIZE;
    uint8_t resp[MAX_APDU_RESP_SIZE] = {0};

    uint32_t fetchRet = ChannelGetOpenChannelResponse(channel, resp, &respLen);
    if (fetchRet != SUCCESS) {
        LOG_ERROR("ChannelGetOpenChannelResponse failed, ret is %u", fetchRet);
        return fetchRet;
    }

    // The support byte lies at DEFAULT_SCP_SUPPORT_FIELD_INDEX, so the response must extend past it.
    if (respLen <= DEFAULT_SCP_SUPPORT_FIELD_INDEX) {
        LOG_ERROR("ChannelGetOpenChannelResponse err, responseLen is %u", respLen);
        return SUCCESS; // fetchRet is SUCCESS here; read by the caller as "not supported"
    }

    uint8_t supportByte = resp[DEFAULT_SCP_SUPPORT_FIELD_INDEX];
    LOG_INFO("ChannelGetOpenChannelResponse support is %u", (uint32_t)supportByte);
    return ((supportByte & 0x01U) != 0);
}